{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/04-finetune.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/04-finetune.ipynb)\n",
    "\n",
    "# Fine-tuning with MSEMargin Loss\n",
    "\n",
    "Now that we have our margin labeled *(Q, P<sup>+</sup>, P<sup>-</sup>)* pairs, we can begin fine-tuning a bi-encoder model with MSEMargin loss. We will start by defining a data loading function that uses the standard `InputExample` format of *sentence-transformers*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a4549bb0e64f495b91db5f6eea7a17f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/200000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200000"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tqdm.auto import tqdm\n",
    "from sentence_transformers import InputExample\n",
    "\n",
    "training_data = []\n",
    "\n",
    "with open('data/triplets_margin.tsv', 'r', encoding='utf-8') as fp:\n",
    "    lines = fp.read().split('\\n')\n",
    "# loop through each line and return InputExample\n",
    "for line in tqdm(lines):\n",
    "    q, p, n, margin = line.split('\\t')\n",
    "    training_data.append(InputExample(\n",
    "        texts=[q, p, n],\n",
    "        label=float(margin)\n",
    "    ))\n",
    "\n",
    "len(training_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load these pairs into a generator `DataLoader`. Margin MSE works best with a large `batch_size`, the `32` used here is reasonable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "loader = torch.utils.data.DataLoader(\n",
    "    training_data, batch_size=batch_size, shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we initialize a bi-encoder model that we will be fine-tuning using domain adaption."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SentenceTransformer(\n",
       "  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: DistilBertModel \n",
       "  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "model = SentenceTransformer('msmarco-distilbert-base-tas-b')\n",
    "model.max_seq_length = 256\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then initialize the Margin MSE loss function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import losses\n",
    "\n",
    "loss = losses.MarginMSELoss(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2055a68e819c4a2295d62029458f7c7e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9e6a24e3909645e299b2e5ec192b2ef8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/6250 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "epochs = 1\n",
    "warmup_steps = int(len(loader) * epochs * 0.1)\n",
    "\n",
    "model.fit(\n",
    "    train_objectives=[(loader, loss)],\n",
    "    epochs=epochs,\n",
    "    warmup_steps=warmup_steps,\n",
    "    output_path='msmarco-distilbert-base-tas-b-covid',\n",
    "    show_progress_bar=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is saved in the `msmarco-distilbert-base-tas-b-covid` directory."
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "conda-root-py",
   "name": "common-cu110.m91",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7 (main, Sep 14 2022, 22:38:23) [Clang 14.0.0 (clang-1400.0.29.102)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
